{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "G3MMAcssHTML"
      },
      "source": [
        "<link rel=\"stylesheet\" href=\"/site-assets/css/gemma.css\">\n",
        "<link rel=\"stylesheet\" href=\"https://fonts.googleapis.com/css2?family=Google+Symbols:opsz,wght,FILL,GRAD@20..48,100..700,0..1,-50..200\" />"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Tce3stUlHN0L"
      },
      "source": [
        "##### Copyright 2025 Google LLC."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "tuOe1ymfHZPu"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SDEExiAk4fLb"
      },
      "source": [
        "# Fine-tune Gemma in Keras using LoRA"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZFWzQEqNosrS"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://ai.google.dev/gemma/docs/core/lora_tuning\"><img src=\"https://ai.google.dev/static/site-assets/images/docs/notebook-site-button.png\" height=\"32\" width=\"32\" />View on ai.google.dev</a>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/google/generative-ai-docs/blob/main/site/en/gemma/docs/core/lora_tuning.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://console.cloud.google.com/vertex-ai/workbench/deploy-notebook?download_url=https://raw.githubusercontent.com/google/generative-ai-docs/main/site/en/gemma/docs/core/lora_tuning.ipynb\"><img src=\"https://ai.google.dev/images/cloud-icon.svg\" width=\"40\" />Open in Vertex AI</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/google/generative-ai-docs/blob/main/site/en/gemma/docs/core/lora_tuning.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lSGRSsRPgkzK"
      },
      "source": [
        "Generative artificial intelligent (AI) models like Gemma are effective at a variety of tasks. You can further fine-tune Gemma models with domain-specific data to perform tasks such as sentiment analysis. However, full fine-tuning of generative models by updating billions of parameters is resource intensive, requiring specialized hardware, such as GPUs, processing time, and memory to load the model parameters.\n",
        "\n",
        "[Low Rank Adaptation](https://arxiv.org/abs/2106.09685) (LoRA) is a fine-tuning technique which greatly reduces the number of trainable parameters for downstream tasks by freezing the weights of the model and inserting a smaller number of new weights into the model. This technique makes training with LoRA much faster and more memory-efficient, and produces smaller model weights (a few hundred MBs), all while maintaining the quality of the model outputs. This tutorial walks you through using Keras to perform LoRA fine-tuning on a Gemma model."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lyhHCMfoRZ_v"
      },
      "source": [
        "## Setup\n",
        "\n",
        "To complete this tutorial, you will first need to complete the setup instructions at [Gemma setup](https://ai.google.dev/gemma/docs/setup). The Gemma setup instructions show you how to do the following:\n",
        "\n",
        "* Get access to Gemma on [kaggle.com](https://kaggle.com).\n",
        "* Select a Colab runtime with sufficient resources to tune\n",
        "  the Gemma model you want to run. [Learn more](https://ai.google.dev/gemma/docs/core#sizes).\n",
        "* Generate and configure a Kaggle username and API key.\n",
        "\n",
        "After you've completed the Gemma setup, move on to the next section, where you'll set environment variables for your Colab environment."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AZ5Qo0fxRZ1V"
      },
      "source": [
        "### Select a Colab runtime\n",
        "\n",
        "To complete this tutorial, you'll need to have a Colab runtime with sufficient resources to run the Gemma model. In this case, you can use a T4 GPU:\n",
        "\n",
        "1. In the upper-right of the Colab window, select &#9662; (**Additional connection options**).\n",
        "2. Select **Change runtime type**.\n",
        "3. Under **Hardware accelerator**, select **T4 GPU**."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hsPC0HRkJl0K"
      },
      "source": [
        "### Configure your API key\n",
        "\n",
        "To use Gemma, you must provide your Kaggle username and a Kaggle API key.\n",
        "\n",
        "To generate a Kaggle API key, go to the **Account** tab of your Kaggle user profile and select **Create New Token**. This triggers the download of a `kaggle.json` file containing your API credentials.\n",
        "\n",
        "In Colab, select **Secrets** (🔑) in the left pane and add your Kaggle username and Kaggle API key. Store your username under the name `KAGGLE_USERNAME` and your API key under the name `KAGGLE_KEY`."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7iOF6Yo-wUEC"
      },
      "source": [
        "### Set environment variables\n",
        "\n",
        "Set environment variables for `KAGGLE_USERNAME` and `KAGGLE_KEY`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0_EdOg9DPK6Q"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "from google.colab import userdata\n",
        "\n",
        "# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env\n",
        "# vars as appropriate for your system.\n",
        "\n",
        "os.environ[\"KAGGLE_USERNAME\"] = userdata.get('KAGGLE_USERNAME')\n",
        "os.environ[\"KAGGLE_KEY\"] = userdata.get('KAGGLE_KEY')"
      ]
    },
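    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "If you're not running in Colab, one possible alternative is to read the credentials from the downloaded `kaggle.json` file. The following is a minimal sketch, not part of the Gemma setup instructions: the helper name `load_kaggle_credentials` is hypothetical, and the file location `~/.kaggle/kaggle.json` is an assumption that you should adjust to wherever you saved the file."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import json\n",
        "import os\n",
        "from pathlib import Path\n",
        "\n",
        "\n",
        "def load_kaggle_credentials(path):\n",
        "    \"\"\"Hypothetical helper: copy kaggle.json credentials into env vars.\"\"\"\n",
        "    creds = json.loads(Path(path).read_text())\n",
        "    os.environ[\"KAGGLE_USERNAME\"] = creds[\"username\"]\n",
        "    os.environ[\"KAGGLE_KEY\"] = creds[\"key\"]\n",
        "\n",
        "\n",
        "# Assumed location of the downloaded file; change this path if needed.\n",
        "kaggle_json = Path.home() / \".kaggle\" / \"kaggle.json\"\n",
        "if kaggle_json.exists():\n",
        "    load_kaggle_credentials(kaggle_json)"
      ]
    },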
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CuEUAKJW1QkQ"
      },
      "source": [
        "### Install Keras packages\n",
        "\n",
        "Install the Keras and KerasHub Python packages."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1eeBtYqJsZPG"
      },
      "outputs": [],
      "source": [
        "!pip install -q -U keras-hub\n",
        "!pip install  -q -U keras"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rGLS-l5TxIR4"
      },
      "source": [
        "### Select a backend\n",
        "\n",
        "Keras is a high-level, multi-framework deep learning API designed for simplicity and ease of use. Using Keras 3, you can run workflows on one of three backends: TensorFlow, JAX, or PyTorch. For this tutorial, configure the backend for JAX as it typically provides the better performance."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yn5uy8X8sdD0"
      },
      "outputs": [],
      "source": [
        "os.environ[\"KERAS_BACKEND\"] = \"jax\"  # Or \"torch\" or \"tensorflow\".\n",
        "# Avoid memory fragmentation on JAX backend.\n",
        "os.environ[\"XLA_PYTHON_CLIENT_MEM_FRACTION\"]=\"1.00\""
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hZs8XXqUKRmi"
      },
      "source": [
        "### Import packages\n",
        "\n",
        "Import the Python packages needed for this tutorial, including Keras and KerasHub."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FYHyPUA9hKTf"
      },
      "outputs": [],
      "source": [
        "import keras\n",
        "import keras_hub"
      ]
    },
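    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Keras reads the `KERAS_BACKEND` environment variable when it is first imported, so the variable has to be set before the import above. As an optional sanity check, you can print the active backend:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import keras\n",
        "\n",
        "# Print the name of the backend that Keras selected at import time.\n",
        "print(keras.backend.backend())"
      ]
    },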
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7RCE3fdGhDE5"
      },
      "source": [
        "## Load model\n",
        "\n",
        "Keras provides implementations of Gemma and many other popular [model architectures](https://keras.io/keras_hub/api/models/). Use the `Gemma3CausalLM.from_preset()` method to configure an end-to-end Gemma model for causal language modeling. A causal language model predicts the next token based on previous tokens."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vz5zLEyLstfn"
      },
      "outputs": [],
      "source": [
        "gemma_lm = keras_hub.models.Gemma3CausalLM.from_preset(\"gemma3_instruct_1b\")\n",
        "gemma_lm.summary()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Nl4lvPy5zA26"
      },
      "source": [
        "The `Gemma3CausalLM.from_preset()` method instantiates the model from a preset architecture and weights. In the code above, the string `\"gemma#_xxxxxxx\"` specifies a preset version and parameter size for Gemma. You can find the code strings for Gemma models in their **Model Variation** listings on [Kaggle](https://www.kaggle.com/models/keras/gemma3)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "G_L6A5J-1QgC"
      },
      "source": [
        "## Inference before fine tuning\n",
        "\n",
        "Once you have downloaded and configured a Gemma model, you can query it with various prompts to see how it responds."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PVLXadptyo34"
      },
      "source": [
        "### Europe trip prompt\n",
        "\n",
        "Query the model for suggestions on what to do on a trip to Europe."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZwQz3xxxKciD"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Instruction:\n",
            "What should I do on a trip to Europe?\n",
            "\n",
            "Response:\n",
            "The first thing to know is that you will have a great time!\n",
            "\n",
            "Europe is a great place for a vacation. The countries of Europe are all very different and offer a wide range of activities and attractions. The countries of Europe are also very close to each other, which means you can visit many different places within a short time.\n",
            "\n",
            "The best way to plan a trip to Europe is to look up the countries you want to visit and see what activities are offered in each country. You can also look for tours and tours that offer a good value for money.\n",
            "\n",
            "You can also look for hotels and flights that offer good deals. If you are looking for a good value for money, you should look for hotels and flights that offer good deals. This means you will have a great time on your trip!\n",
            "\n",
            "The next step is to book your tickets to the countries you want to visit. If you are planning to visit many countries, it's a good idea to book your tickets early. This means you’ll be able to get the best deal and avoid the long queues.\n",
            "\n",
            "The next step is to plan your itinerary. You can use a travel guide to plan your itinerary\n"
          ]
        }
      ],
      "source": [
        "template = \"Instruction:\\n{instruction}\\n\\nResponse:\\n{response}\"\n",
        "\n",
        "prompt = template.format(\n",
        "    instruction=\"What should I do on a trip to Europe?\",\n",
        "    response=\"\",\n",
        ")\n",
        "sampler = keras_hub.samplers.TopKSampler(k=5, seed=2)\n",
        "gemma_lm.compile(sampler=sampler)\n",
        "print(gemma_lm.generate(prompt, max_length=256))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AePQUIs2h-Ks"
      },
      "source": [
        "The model responds with generic tips on how to plan a trip."
      ]
    },
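    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The same prompt template is reused throughout this tutorial. If you prefer, you could wrap it in a small helper. This is only a sketch: the name `build_prompt` is hypothetical and is not part of KerasHub."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "PROMPT_TEMPLATE = \"Instruction:\\n{instruction}\\n\\nResponse:\\n{response}\"\n",
        "\n",
        "\n",
        "def build_prompt(instruction, response=\"\"):\n",
        "    \"\"\"Hypothetical helper: fill the instruction/response template.\"\"\"\n",
        "    return PROMPT_TEMPLATE.format(instruction=instruction, response=response)\n",
        "\n",
        "\n",
        "# An empty response leaves the text after \"Response:\" for the model to generate.\n",
        "print(build_prompt(\"What should I do on a trip to Europe?\"))"
      ]
    },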
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YQ74Zz_S0iVv"
      },
      "source": [
        "### Photosynthesis prompt\n",
        "\n",
        "Prompt the model to explain photosynthesis in terms simple enough for a 5 year old child to understand."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lorJMbsusgoo"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Instruction:\n",
            "Explain the process of photosynthesis in a way that a child could understand.\n",
            "\n",
            "Response:\n",
            "Photosynthesis is a biological process that occurs in plants, algae, and some other organisms. In the process, light energy is captured and converted into the energy stored in the bonds of organic molecules. The process is crucial for life on Earth because it enables plants to use carbon dioxide and water to produce glucose and oxygen, which are essential for all living things.\n",
            "The process involves several stages:\n",
            "1. Light Reactions: Light energy is absorbed by pigments in the chloroplasts of the plant, converting it into chemical energy in the form of ATP and reducing power.\n",
            "2. Carbon Fixation: During this stage, carbon dioxide is combined with hydrogen to form organic molecules such as starch or glucose, which are used as a source of energy.\n",
            "3. Calvin Cycle: The process of carbon fixation occurs in the stroma of the chloroplasts. It involves the capture and reduction of carbon dioxide, producing glucose and reducing power in the form of ATP and NADPH molecules.\n",
            "4. Stroma: The stroma is the fluid-filled space where the light reactions occur in the chloroplasts.\n",
            "5. Chloroplasts: The chloroplasts contain the green pigments that absorb\n"
          ]
        }
      ],
      "source": [
        "prompt = template.format(\n",
        "    instruction=\"Explain the process of photosynthesis in a way that a child could understand.\",\n",
        "    response=\"\",\n",
        ")\n",
        "print(gemma_lm.generate(prompt, max_length=256))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WBQieduRizZf"
      },
      "source": [
        "The model response contains words that might not be easy to understand for a child such as chlorophyll."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Pt7Nr6a7tItO"
      },
      "source": [
        "## LoRA fine-tuning\n",
        "\n",
        "This section shows you how to do fine-tuning using the Low Rank Adaptation (LoRA) tuning technique. This approach allows you to change the behavior of Gemma models using fewer compute resources."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9T7xe_jzslv4"
      },
      "source": [
        "### Load dataset\n",
        "\n",
        "Prepare a dataset for tuning by downloading an existing data set and formatting if for use with the the Keras `fit()` fine-tuning method. This tutorial uses the [Databricks Dolly 15k dataset](https://huggingface.co/datasets/databricks/databricks-dolly-15k) for fine-tuning. The dataset contains 15,000 high-quality human-generated prompt and response pairs specifically designed for tuning generative models."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xRaNCPUXKoa7"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "--2025-04-10 20:48:49--  https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl\n",
            "Resolving huggingface.co (huggingface.co)... 3.163.189.37, 3.163.189.114, 3.163.189.74, ...\n",
            "Connecting to huggingface.co (huggingface.co)|3.163.189.37|:443... connected.\n",
            "HTTP request sent, awaiting response... 302 Found\n",
            "Location: https://cdn-lfs.hf.co/repos/34/ac/34ac588cc580830664f592597bb6d19d61639eca33dc2d6bb0b6d833f7bfd552/2df9083338b4abd6bceb5635764dab5d833b393b55759dffb0959b6fcbf794ec?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27databricks-dolly-15k.jsonl%3B+filename%3D%22databricks-dolly-15k.jsonl%22%3B&Expires=1744321729&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTc0NDMyMTcyOX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5oZi5jby9yZXBvcy8zNC9hYy8zNGFjNTg4Y2M1ODA4MzA2NjRmNTkyNTk3YmI2ZDE5ZDYxNjM5ZWNhMzNkYzJkNmJiMGI2ZDgzM2Y3YmZkNTUyLzJkZjkwODMzMzhiNGFiZDZiY2ViNTYzNTc2NGRhYjVkODMzYjM5M2I1NTc1OWRmZmIwOTU5YjZmY2JmNzk0ZWM%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=vh0VIGB-UkK57FSfRikYCREpKuHt%7EnDKPcHHgC1V9rDXLABIRF81nK7olQhAq6zSbAqEtMNnvHgd8IBK1j54mdIYdVLiBwImqez3xu2CPhzYBtKWInnXj9lTXW0p-9GEHcbU%7Eoot22qFSdwyZf1UIdmHZLTHPWjtLhfRkKbg-ptA3CFeegtmvCtY-WG2GffJ%7Em2q2bbs-U1m0yI7cSTW18nD8VSBihxGOMnS1IhkO-LgE4I6GJISXROTk-61%7EJiEIKcagcijL4QGi8j1g9xeQamBXX4hWBdkbJgX5PtX15Ftd0HCM4zCzcJAUrE3ZEJRLe2XRUwfKU3ai7-%7ErPpnSA__&Key-Pair-Id=K3RPWS32NSSJCE [following]\n",
            "--2025-04-10 20:48:49--  https://cdn-lfs.hf.co/repos/34/ac/34ac588cc580830664f592597bb6d19d61639eca33dc2d6bb0b6d833f7bfd552/2df9083338b4abd6bceb5635764dab5d833b393b55759dffb0959b6fcbf794ec?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27databricks-dolly-15k.jsonl%3B+filename%3D%22databricks-dolly-15k.jsonl%22%3B&Expires=1744321729&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTc0NDMyMTcyOX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5oZi5jby9yZXBvcy8zNC9hYy8zNGFjNTg4Y2M1ODA4MzA2NjRmNTkyNTk3YmI2ZDE5ZDYxNjM5ZWNhMzNkYzJkNmJiMGI2ZDgzM2Y3YmZkNTUyLzJkZjkwODMzMzhiNGFiZDZiY2ViNTYzNTc2NGRhYjVkODMzYjM5M2I1NTc1OWRmZmIwOTU5YjZmY2JmNzk0ZWM%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=vh0VIGB-UkK57FSfRikYCREpKuHt%7EnDKPcHHgC1V9rDXLABIRF81nK7olQhAq6zSbAqEtMNnvHgd8IBK1j54mdIYdVLiBwImqez3xu2CPhzYBtKWInnXj9lTXW0p-9GEHcbU%7Eoot22qFSdwyZf1UIdmHZLTHPWjtLhfRkKbg-ptA3CFeegtmvCtY-WG2GffJ%7Em2q2bbs-U1m0yI7cSTW18nD8VSBihxGOMnS1IhkO-LgE4I6GJISXROTk-61%7EJiEIKcagcijL4QGi8j1g9xeQamBXX4hWBdkbJgX5PtX15Ftd0HCM4zCzcJAUrE3ZEJRLe2XRUwfKU3ai7-%7ErPpnSA__&Key-Pair-Id=K3RPWS32NSSJCE\n",
            "Resolving cdn-lfs.hf.co (cdn-lfs.hf.co)... 18.238.217.63, 18.238.217.81, 18.238.217.120, ...\n",
            "Connecting to cdn-lfs.hf.co (cdn-lfs.hf.co)|18.238.217.63|:443... connected.\n",
            "HTTP request sent, awaiting response... 200 OK\n",
            "Length: 13085339 (12M) [text/plain]\n",
            "Saving to: ‘databricks-dolly-15k.jsonl’\n",
            "\n",
            "databricks-dolly-15 100%[===================>]  12.48M  --.-KB/s    in 0.08s   \n",
            "\n",
            "2025-04-10 20:48:49 (156 MB/s) - ‘databricks-dolly-15k.jsonl’ saved [13085339/13085339]\n",
            "\n"
          ]
        }
      ],
      "source": [
        "!wget -O databricks-dolly-15k.jsonl https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "45UpBDfBgf0I"
      },
      "source": [
        "### Format tuning data\n",
        "\n",
        "Format the downloaded data for use with the Keras `fit()` method. The following code extracts a subset of the training examples to execute the notebook faster. Consider using more training data for higher quality fine-tuning."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZiS-KU9osh_N"
      },
      "outputs": [],
      "source": [
        "import json\n",
        "\n",
        "prompts = []\n",
        "responses = []\n",
        "line_count = 0\n",
        "\n",
        "with open(\"databricks-dolly-15k.jsonl\") as file:\n",
        "    for line in file:\n",
        "        if line_count >= 1000:\n",
        "            break  # Limit the training examples, to reduce execution time.\n",
        "\n",
        "        examples = json.loads(line)\n",
        "        # Filter out examples with context, to keep it simple.\n",
        "        if examples[\"context\"]:\n",
        "            continue\n",
        "        # Format data into prompts and response lists.\n",
        "        prompts.append(examples[\"instruction\"])\n",
        "        responses.append(examples[\"response\"])\n",
        "\n",
        "        line_count += 1\n",
        "\n",
        "data = {\n",
        "    \"prompts\": prompts,\n",
        "    \"responses\": responses\n",
        "}"
      ]
    },
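    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Before training, it can help to confirm that the data looks as expected. This optional check assumes that the `data` dictionary from the previous cell is in scope:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# The two lists should line up one-to-one (at most 1000 examples each).\n",
        "assert len(data[\"prompts\"]) == len(data[\"responses\"])\n",
        "print(f\"Number of training examples: {len(data['prompts'])}\")\n",
        "\n",
        "# Show the first prompt and response pair.\n",
        "print(\"Prompt:\", data[\"prompts\"][0])\n",
        "print(\"Response:\", data[\"responses\"][0])"
      ]
    },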
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cBLW5hiGj31i"
      },
      "source": [
        "### Configure LoRA tuning\n",
        "\n",
        "Activate LoRA tuning using the Keras `model.backbone.enable_lora()` method, including a LoRA rank value. The *LoRA rank* determines the dimensionality of the trainable matrices that are added to the original weights of the LLM. It controls the expressiveness and precision of the fine-tuning adjustments. A higher rank means more detailed changes are possible, but also means more trainable parameters. A lower rank means less computational overhead, but potentially less precise adaptation.\n",
        "\n",
        "This example uses a LoRA rank of 4. In practice, begin with a relatively small rank (such as 4, 8, 16). This setting is computationally efficient for experimentation. Train your model with this rank and evaluate the performance improvement on your task. Gradually increase the rank in subsequent trials and see if that further boosts performance."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RCucu6oHz53G"
      },
      "outputs": [],
      "source": [
        "# Enable LoRA for the model and set the LoRA rank to 4.\n",
        "gemma_lm.backbone.enable_lora(rank=4)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PlMLp_NVbRoQ"
      },
      "source": [
        "Check the model summary after setting the LoRA rank. Notice that enabling LoRA reduces the number of trainable parameters significantly compared to the total number of parameters in the model:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KqYyS0gm6pNy"
      },
      "outputs": [],
      "source": [
        "gemma_lm.summary()"
      ]
    },
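    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To see why the trainable parameter count drops, consider a single dense layer with a weight matrix of shape `d_in` x `d_out`. LoRA freezes that matrix and trains two smaller matrices of shapes `d_in` x `rank` and `rank` x `d_out`, which adds `rank * (d_in + d_out)` trainable parameters. The following back-of-the-envelope sketch uses a hypothetical layer shape, not values read from the Gemma model:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def lora_param_count(d_in, d_out, rank):\n",
        "    \"\"\"Trainable parameters LoRA adds to one d_in x d_out weight matrix.\"\"\"\n",
        "    # Matrix A has d_in * rank entries and matrix B has rank * d_out entries.\n",
        "    return rank * (d_in + d_out)\n",
        "\n",
        "\n",
        "d_in, d_out = 1024, 1024  # Hypothetical layer shape, for illustration only.\n",
        "frozen = d_in * d_out\n",
        "\n",
        "for rank in (4, 8, 16):\n",
        "    added = lora_param_count(d_in, d_out, rank)\n",
        "    print(f\"rank={rank}: {added} trainable vs. {frozen} frozen ({added / frozen:.2%})\")"
      ]
    },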
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hQQ47kcdpbZ9"
      },
      "source": [
        "Configure the rest of the fine-tuning settings, including the preprocessor settings, optimizer, number of tuning epochs, and batch size:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "p9sBNH8SAjgB"
      },
      "outputs": [],
      "source": [
        "# Limit the input sequence length to 256 (to control memory usage).\n",
        "gemma_lm.preprocessor.sequence_length = 256\n",
        "# Use AdamW (a common optimizer for transformer models).\n",
        "optimizer = keras.optimizers.AdamW(\n",
        "    learning_rate=5e-5,\n",
        "    weight_decay=0.01,\n",
        ")\n",
        "# Exclude layernorm and bias terms from decay.\n",
        "optimizer.exclude_from_weight_decay(var_names=[\"bias\", \"scale\"])\n",
        "\n",
        "gemma_lm.compile(\n",
        "    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
        "    optimizer=optimizer,\n",
        "    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "OA0ozGC66tk1"
      },
      "source": [
        "### Run the fine-tune process\n",
        "\n",
        "Run the fine-tuning process using the `fit()` method. This process can take several minutes depending on your compute resources, data size, and number of epochs:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_Peq7TnLtHse"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\u001b[1m1000/1000\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m923s\u001b[0m 888ms/step - loss: 1.5586 - sparse_categorical_accuracy: 0.5251\n"
          ]
        },
        {
          "data": {
            "text/plain": [
              "<keras.src.callbacks.history.History at 0x799d04393c40>"
            ]
          },
          "execution_count": 11,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "gemma_lm.fit(data, epochs=1, batch_size=1)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bx3m8f1dB7nk"
      },
      "source": [
        "#### Mixed precision fine-tuning on NVIDIA GPUs\n",
        "\n",
        "Full precision is recommended for fine-tuning. When fine-tuning on NVIDIA GPUs, you can use mixed precision (`keras.mixed_precision.set_global_policy('mixed_bfloat16')`) to speed up training with minimal effect on training quality."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "T0lHxEDX03gp"
      },
      "outputs": [],
      "source": [
        "# Uncomment the line below if you want to enable mixed precision training on GPUs\n",
        "# keras.mixed_precision.set_global_policy('mixed_bfloat16')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4yd-1cNw1dTn"
      },
      "source": [
        "## Inference after fine-tuning\n",
        "\n",
        "After fine-tuning, you should see changes in the responses when the tuned model is given the same prompt."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "H55JYJ1a1Kos"
      },
      "source": [
        "### Europe trip prompt\n",
        "\n",
        "Try the Europe trip prompt from earlier and note the differences in the response."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Y7cDJHy8WfCB"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Instruction:\n",
            "What should I do on a trip to Europe?\n",
            "\n",
            "Response:\n",
            "When planning a trip to Europe, you should consider your budget, time and the places you want to visit. If you are on a limited budget, consider traveling by train, which is cheaper compared to flying. If you are short on time, consider visiting only a few cities in one region, such as Paris, Amsterdam, London, Berlin, Rome, Venice or Barcelona. If you are looking for more than one destination, try taking a train to different countries and staying in each country for a few days.\n"
          ]
        }
      ],
      "source": [
        "prompt = template.format(\n",
        "    instruction=\"What should I do on a trip to Europe?\",\n",
        "    response=\"\",\n",
        ")\n",
        "sampler = keras_hub.samplers.TopKSampler(k=5, seed=2)\n",
        "gemma_lm.compile(sampler=sampler)\n",
        "print(gemma_lm.generate(prompt, max_length=256))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "OXP6gg2mjs6u"
      },
      "source": [
        "The model now provides a shorter response to a question about visiting Europe."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "H7nVd8Mi1Yta"
      },
      "source": [
        "### Photosynthesis prompt\n",
        "\n",
        "Try the photosynthesis explanation prompt from earlier and note the differences in the response."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "X-2sYl2jqwl7"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Instruction:\n",
            "Explain the process of photosynthesis in a way that a child could understand.\n",
            "\n",
            "Response:\n",
            "The process of photosynthesis is a chemical reaction in plants that converts the energy of sunlight into chemical energy, which the plants can then use to grow and develop. During photosynthesis, a plant will absorb carbon dioxide (CO2) from the air and water from the soil and use the energy from the sun to produce oxygen (O2) and sugars (glucose) as a by-product.\n"
          ]
        }
      ],
      "source": [
        "prompt = template.format(\n",
        "    instruction=\"Explain the process of photosynthesis in a way that a child could understand.\",\n",
        "    response=\"\",\n",
        ")\n",
        "print(gemma_lm.generate(prompt, max_length=256))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PCmAmqrvkEhc"
      },
      "source": [
        "The model now explains photosynthesis in simpler terms."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "I8kFG12l0mVe"
      },
      "source": [
        "## Improving fine-tune results\n",
        "\n",
        "For demonstration purposes, this tutorial fine-tunes the model on a small subset of the dataset for just one epoch and with a low LoRA rank value. To get better responses from the fine-tuned model, you can experiment with:\n",
        "\n",
        "1. Increasing the size of the fine-tuning dataset\n",
        "2. Training for more steps (epochs)\n",
        "3. Setting a higher LoRA rank\n",
        "4. Modifying the hyperparameter values such as `learning_rate` and `weight_decay`."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gSsRdeiof_rJ"
      },
      "source": [
        "## Summary and next steps\n",
        "\n",
        "This tutorial covered LoRA fine-tuning on a Gemma model using Keras. Check out the following docs next:\n",
        "\n",
        "* Learn how to [generate text with a Gemma model](https://ai.google.dev/gemma/docs/get_started).\n",
        "* Learn how to perform [distributed fine-tuning and inference on a Gemma model](https://ai.google.dev/gemma/docs/core/distributed_tuning).\n",
        "* Learn how to [use Gemma open models with Vertex AI](https://cloud.google.com/vertex-ai/docs/generative-ai/open-models/use-gemma).\n",
        "* Learn how to [fine-tune Gemma using Keras and deploy to Vertex AI](https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/model_garden/model_garden_gemma_kerasnlp_to_vertexai.ipynb)."
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "name": "lora_tuning.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
